package com.gitee.fastmybatis.core.support.plugin;

import com.gitee.fastmybatis.core.FastmybatisConfig;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.plugin.Invocation;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.sql.PreparedStatement;
import java.util.Properties;

/**
 * SQL格式化默认实现
 * @author thc
 */
/**
 * Default {@link SqlFormatterHandler} implementation.
 * <p>
 * Recovers the executed SQL text from the {@link PreparedStatement} passed as the first argument of
 * the intercepted call, unwrapping MyBatis' {@code PreparedStatementLogger} proxy and, when
 * {@code useDruid} is set, two further Druid layers. It then formats that SQL with
 * {@link SqlFormatter}, wraps it in a configurable template, and logs it.
 * <p>
 * Configurable through {@link #setProperties(Properties)} with the keys {@code enable},
 * {@code appendDelimiter}, {@code delimiter}, {@code useDruid} and {@code format}.
 *
 * @author thc
 */
public class DefaultSqlFormatterHandler implements SqlFormatterHandler {

    protected static final Log LOG = LogFactory.getLog(DefaultSqlFormatterHandler.class);

    /** Default statement terminator appended to the printed SQL. */
    private static final String DEFAULT_DELIMITER = ";";

    /** Default output template; {@code %s} receives the formatted SQL. */
    private static final String DEFAULT_FORMAT = "\n====== SQL ======\n%s";

    /**
     * Placeholder text replaced by {@code null} in the printed SQL. It is presumably emitted by the
     * JDBC driver's {@code toString()} for binary parameters; this depends on the driver.
     */
    private static final String BYTE_ARRAY_MARKER = "** BYTE ARRAY DATA **";

    private final SqlFormatter sqlFormatter = new SqlFormatter();

    /** The properties last passed to {@link #setProperties(Properties)}; may be {@code null}. */
    private Properties properties;

    /** Whether to append {@link #delimiter} to the printed SQL when it does not already end with it. */
    private boolean appendDelimiter = true;

    /** Statement terminator appended when {@link #appendDelimiter} is enabled. */
    private String delimiter = DEFAULT_DELIMITER;

    /** Whether the connection pool is Druid, which requires extra unwrapping in {@link #getSql}. */
    private boolean useDruid;

    /** {@link String#format} template applied to the formatted SQL; expected to contain one {@code %s}. */
    private String format = DEFAULT_FORMAT;

    /** When {@code true}, SQL is printed even if DEBUG logging and the global switch are off. */
    private boolean enable;

    /**
     * Decides whether SQL should be printed.
     *
     * @param invocation the intercepted invocation (not inspected)
     * @return {@code true} if enabled via the {@code enable} property, the global
     *         {@code FastmybatisConfig.PRINT_SQL} switch, or DEBUG logging on this class's logger
     */
    @Override
    public boolean canPrint(Invocation invocation) {
        return this.enable || FastmybatisConfig.PRINT_SQL || LOG.isDebugEnabled();
    }

    /**
     * Extracts printable SQL text from the intercepted statement.
     * <p>
     * The statement is unwrapped down to the object whose {@code toString()} carries the SQL. The
     * method then drops everything up to and including the first {@code ':'} (presumably the
     * driver's class-name prefix; depends on the driver). It replaces binary-data placeholders with
     * {@code null} and appends the delimiter if configured.
     *
     * @param invocation the intercepted invocation; its first argument must be a {@link PreparedStatement}
     * @return the extracted SQL text
     * @throws Exception if reflective unwrapping fails (e.g. an expected field is missing)
     */
    @Override
    public String getSql(Invocation invocation) throws Exception {
        PreparedStatement statement = unwrapStatement((PreparedStatement) invocation.getArgs()[0]);
        return extractSql(statement.toString());
    }

    /**
     * Unwraps MyBatis' {@code PreparedStatementLogger} JDK proxy and, when {@link #isUseDruid()} is
     * set, the two Druid wrapper layers, returning the innermost statement reached.
     */
    private PreparedStatement unwrapStatement(PreparedStatement statement) throws ReflectiveOperationException {
        if (!Proxy.isProxyClass(statement.getClass())) {
            return statement;
        }
        InvocationHandler handler = Proxy.getInvocationHandler(statement);
        if (!handler.getClass().getName().endsWith(".PreparedStatementLogger")) {
            return statement;
        }
        PreparedStatement target = (PreparedStatement) readField(handler, "statement");
        // Reference: https://www.zhangshengrong.com/p/OgN5DgLDan/
        // Druid-specific: the statement held by the logger is wrapped again. It is reached through
        // field "stmt" and then field "statement". Only attempted when useDruid is configured.
        if (isUseDruid()) {
            target = (PreparedStatement) readField(target, "stmt");
            target = (PreparedStatement) readField(target, "statement");
        }
        return target;
    }

    /** Reads a field declared directly on {@code target}'s runtime class, bypassing access checks. */
    private static Object readField(Object target, String name) throws ReflectiveOperationException {
        Field field = target.getClass().getDeclaredField(name);
        field.setAccessible(true);
        return field.get(target);
    }

    /** Converts a statement's {@code toString()} output into the SQL text to print. */
    private String extractSql(String statementSql) {
        // Cut at the first ':' before replacing, so the index is taken from the same string it is
        // applied to; the marker contains no ':' so the result is otherwise unchanged.
        String sql = statementSql
                .substring(statementSql.indexOf(':') + 1)
                .replace(BYTE_ARRAY_MARKER, "null");
        String endMark = getDelimiter();
        if (isAppendDelimiter() && endMark != null && !sql.endsWith(endMark)) {
            sql = sql + endMark;
        }
        return sql;
    }

    /**
     * Pretty-prints the SQL with {@link SqlFormatter} and wraps it in {@link #getFormat()}.
     *
     * @param sql raw SQL text
     * @return the formatted SQL embedded in the configured template
     */
    @Override
    public String formatSql(String sql) {
        String fullSql = getSqlFormatter().format(sql);
        // Apply the user-configurable output template.
        return String.format(getFormat(), fullSql);
    }

    /**
     * Logs the SQL at WARN level. Presumably this is so the output is visible when printing was
     * switched on via {@code enable} or {@code PRINT_SQL} while DEBUG is off; confirm intent.
     *
     * @param sql text to log
     */
    @Override
    public void printSql(String sql) {
        LOG.warn(sql);
    }

    public SqlFormatter getSqlFormatter() {
        return sqlFormatter;
    }

    /** @return the properties last passed to {@link #setProperties(Properties)}, possibly {@code null} */
    public Properties getProperties() {
        return properties;
    }

    public boolean isEnable() {
        return enable;
    }

    public void setEnable(boolean enable) {
        this.enable = enable;
    }

    public boolean isAppendDelimiter() {
        return appendDelimiter;
    }

    public void setAppendDelimiter(boolean appendDelimiter) {
        this.appendDelimiter = appendDelimiter;
    }

    public String getDelimiter() {
        return delimiter;
    }

    public void setDelimiter(String delimiter) {
        this.delimiter = delimiter;
    }

    public boolean isUseDruid() {
        return useDruid;
    }

    public void setUseDruid(boolean useDruid) {
        this.useDruid = useDruid;
    }

    public String getFormat() {
        return format;
    }

    public void setFormat(String format) {
        this.format = format;
    }

    /**
     * Configures this handler from properties.
     * <p>
     * Missing keys fall back to the defaults: {@code appendDelimiter=true}, {@code enable=false},
     * {@code useDruid=false}, {@code delimiter=";"}, and {@code format} equal to the built-in template.
     * A {@code null} argument is stored but leaves the current settings unchanged.
     *
     * @param properties configuration; may be {@code null}
     */
    @Override
    public void setProperties(Properties properties) {
        this.properties = properties;
        if (properties == null) {
            // Nothing to read; keep the current configuration.
            return;
        }
        this.appendDelimiter = Boolean.parseBoolean(properties.getProperty("appendDelimiter", "true"));
        this.enable = Boolean.parseBoolean(properties.getProperty("enable", "false"));
        this.useDruid = Boolean.parseBoolean(properties.getProperty("useDruid", "false"));
        this.delimiter = properties.getProperty("delimiter", DEFAULT_DELIMITER);
        // A literal backslash-n in configuration text is turned into a real newline.
        this.format = properties.getProperty("format", DEFAULT_FORMAT).replace("\\n", "\n");
    }
}
